{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DjUA6S30k52h"
      },
      "source": [
        "##### Copyright 2021 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "SpNWyqewk8fE"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6x1ypzczQCwy"
      },
      "source": [
        "# Data validation using TFX Pipeline and TensorFlow Data Validation"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HU9YYythm0dx"
      },
      "source": [
        "Note: We recommend running this tutorial in a Colab notebook, which requires no setup. Just click \"Run in Google Colab\".\n",
        "\n",
        "<div class=\"devsite-table-wrapper\"><table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "<td><a target=\"_blank\" href=\"https://www.tensorflow.org/tfx/tutorials/tfx/penguin_tfdv\">\n",
        "<img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\"/>View on TensorFlow.org</a></td>\n",
        "<td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tfx/blob/master/docs/tutorials/tfx/penguin_tfdv.ipynb\">\n",
        "<img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\">Run in Google Colab</a></td>\n",
        "<td><a target=\"_blank\" href=\"https://github.com/tensorflow/tfx/tree/master/docs/tutorials/tfx/penguin_tfdv.ipynb\">\n",
        "<img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\">View source on GitHub</a></td>\n",
        "<td><a href=\"https://storage.googleapis.com/tensorflow_docs/tfx/docs/tutorials/tfx/penguin_tfdv.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a></td>\n",
        "</table></div>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_VuwrlnvQJ5k"
      },
      "source": [
        "In this notebook-based tutorial, we will create and run TFX pipelines\n",
        "to validate input data and create an ML model. This notebook is based on the\n",
        "TFX pipeline we built in\n",
        "[Simple TFX Pipeline Tutorial](https://www.tensorflow.org/tfx/tutorials/tfx/penguin_simple).\n",
        "If you have not read that tutorial yet, you should read it before proceeding\n",
        "with this notebook.\n",
        "\n",
        "The first task in any data science or ML project is to understand and clean\n",
        "the data, which includes:\n",
        "- Understanding the data types, distributions, and other information (e.g.,\n",
        "mean value, or number of unique values) about each feature\n",
        "- Generating a preliminary schema that describes the data\n",
        "- Identifying anomalies and missing values in the data with respect to a given\n",
        "schema\n",
        "\n",
        "In this tutorial, we will create two TFX pipelines.\n",
        "\n",
        "First, we will create a pipeline to analyze the dataset and generate a\n",
        "preliminary schema of the given dataset. This pipeline will include two new\n",
        "components, `StatisticsGen` and `SchemaGen`.\n",
        "\n",
        "Once we have a proper schema of the data, we will create a pipeline to train\n",
        "an ML classification model based on the pipeline from the previous tutorial.\n",
        "In this pipeline, we will use the schema from the first pipeline and a\n",
        "new component, `ExampleValidator`, to validate the input data.\n",
        "\n",
        "The three new components, StatisticsGen, SchemaGen and ExampleValidator, are\n",
        "TFX components for data analysis and validation, and they are implemented\n",
        "using the\n",
        "[TensorFlow Data Validation](https://www.tensorflow.org/tfx/guide/tfdv) library.\n",
        "\n",
        "Please see\n",
        "[Understanding TFX Pipelines](https://www.tensorflow.org/tfx/guide/understanding_tfx_pipelines)\n",
        "to learn more about various concepts in TFX."
      ]
    },
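    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The three tasks listed above (computing statistics, deriving a schema, and\n",
        "checking data against it) can be illustrated with a minimal sketch that uses\n",
        "only the Python standard library. This is a toy under stated assumptions, not\n",
        "how TFDV is implemented: the helpers `compute_stats`, `infer_schema` and\n",
        "`find_anomalies` and the sample rows are hypothetical, and every row is\n",
        "assumed to carry the same keys.\n",
        "\n",
        "```python\n",
        "# Toy illustration only: helper names and data are hypothetical, not TFDV APIs.\n",
        "def compute_stats(rows):\n",
        "  '''Returns per-feature type name, min, max and missing-value count.'''\n",
        "  stats = {}\n",
        "  for row in rows:\n",
        "    for name, value in row.items():\n",
        "      s = stats.setdefault(\n",
        "          name, {'type': None, 'min': None, 'max': None, 'missing': 0})\n",
        "      if value is None:\n",
        "        s['missing'] += 1\n",
        "        continue\n",
        "      s['type'] = type(value).__name__\n",
        "      s['min'] = value if s['min'] is None else min(s['min'], value)\n",
        "      s['max'] = value if s['max'] is None else max(s['max'], value)\n",
        "  return stats\n",
        "\n",
        "\n",
        "def infer_schema(stats):\n",
        "  '''Derives a preliminary schema (type and observed range) from statistics.'''\n",
        "  return {name: {'type': s['type'], 'min': s['min'], 'max': s['max']}\n",
        "          for name, s in stats.items()}\n",
        "\n",
        "\n",
        "def find_anomalies(rows, schema):\n",
        "  '''Lists (row index, feature, reason) for values that violate the schema.'''\n",
        "  anomalies = []\n",
        "  for i, row in enumerate(rows):\n",
        "    for name, spec in schema.items():\n",
        "      value = row.get(name)\n",
        "      if value is None:\n",
        "        anomalies.append((i, name, 'missing value'))\n",
        "      elif type(value).__name__ != spec['type']:\n",
        "        anomalies.append((i, name, 'unexpected type'))\n",
        "      elif not spec['min'] <= value <= spec['max']:\n",
        "        anomalies.append((i, name, 'out of range'))\n",
        "  return anomalies\n",
        "\n",
        "\n",
        "# Hypothetical rows used to derive the schema.\n",
        "train_rows = [\n",
        "    {'body_mass_g': 0.25, 'species': 0},\n",
        "    {'body_mass_g': 0.75, 'species': 2},\n",
        "]\n",
        "schema = infer_schema(compute_stats(train_rows))\n",
        "\n",
        "# Hypothetical new rows checked against that schema.\n",
        "new_rows = [\n",
        "    {'body_mass_g': 0.5, 'species': 1},\n",
        "    {'body_mass_g': 1.5, 'species': 1},\n",
        "    {'body_mass_g': None, 'species': 3},\n",
        "]\n",
        "anomalies = find_anomalies(new_rows, schema)\n",
        "print(anomalies)\n",
        "```\n",
        "\n",
        "The second and third new rows are reported because they contain an\n",
        "out-of-range value, a missing value, and a label outside the observed range."
      ]
    },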
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Fmgi8ZvQkScg"
      },
      "source": [
        "## Set Up\n",
        "We first need to install the TFX Python package and download\n",
        "the dataset which we will use for our model.\n",
        "\n",
        "### Upgrade Pip\n",
        "\n",
        "To avoid upgrading Pip on a local system,\n",
        "we first check that we are running in Colab.\n",
        "Local systems can of course be upgraded separately."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "as4OTe2ukSqm"
      },
      "outputs": [],
      "source": [
        "try:\n",
        "  import colab\n",
        "  !pip install --upgrade pip\n",
        "except ImportError:\n",
        "  pass  # Not running in Colab; leave the local pip untouched."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MZOYTt1RW4TK"
      },
      "source": [
        "### Install TFX\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iyQtljP-qPHY"
      },
      "outputs": [],
      "source": [
        "!pip install -U tfx"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EwT0nov5QO1M"
      },
      "source": [
        "### Did you restart the runtime?\n",
        "\n",
        "If you are using Google Colab, the first time that you run\n",
        "the cell above, you must restart the runtime by clicking\n",
        "the \"RESTART RUNTIME\" button above or using the \"Runtime > Restart\n",
        "runtime ...\" menu. This is because of the way that Colab\n",
        "loads packages."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BDnPgN8UJtzN"
      },
      "source": [
        "Check the TensorFlow and TFX versions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6jh7vKSRqPHb"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "print('TensorFlow version: {}'.format(tf.__version__))\n",
        "from tfx import v1 as tfx\n",
        "print('TFX version: {}'.format(tfx.__version__))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aDtLdSkvqPHe"
      },
      "source": [
        "### Set up variables\n",
        "\n",
        "The following variables are used to define a pipeline. You can customize them\n",
        "as needed. By default, all output from the pipeline will be\n",
        "generated under the current directory."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EcUseqJaE2XN"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "# We will create two pipelines. One for schema generation and one for training.\n",
        "SCHEMA_PIPELINE_NAME = \"penguin-tfdv-schema\"\n",
        "PIPELINE_NAME = \"penguin-tfdv\"\n",
        "\n",
        "# Output directory to store artifacts generated from the pipeline.\n",
        "SCHEMA_PIPELINE_ROOT = os.path.join('pipelines', SCHEMA_PIPELINE_NAME)\n",
        "PIPELINE_ROOT = os.path.join('pipelines', PIPELINE_NAME)\n",
        "# Path to a SQLite DB file to use as an MLMD storage.\n",
        "SCHEMA_METADATA_PATH = os.path.join('metadata', SCHEMA_PIPELINE_NAME,\n",
        "                                    'metadata.db')\n",
        "METADATA_PATH = os.path.join('metadata', PIPELINE_NAME, 'metadata.db')\n",
        "\n",
        "# Output directory where created models from the pipeline will be exported.\n",
        "SERVING_MODEL_DIR = os.path.join('serving_model', PIPELINE_NAME)\n",
        "\n",
        "from absl import logging\n",
        "logging.set_verbosity(logging.INFO)  # Set default logging level."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qsO0l5F3dzOr"
      },
      "source": [
        "### Prepare example data\n",
        "We will download the example dataset for use in our TFX pipeline. The dataset\n",
        "we are using is the\n",
        "[Palmer Penguins dataset](https://allisonhorst.github.io/palmerpenguins/articles/intro.html),\n",
        "which is also used in other\n",
        "[TFX examples](https://github.com/tensorflow/tfx/tree/master/tfx/examples/penguin).\n",
        "\n",
        "There are four numeric features in this dataset:\n",
        "\n",
        "- culmen_length_mm\n",
        "- culmen_depth_mm\n",
        "- flipper_length_mm\n",
        "- body_mass_g\n",
        "\n",
        "All features were already normalized to have range [0,1]. We will build a\n",
        "classification model which predicts the `species` of penguins."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IjE8MkZidzO0"
      },
      "source": [
        "Because the TFX ExampleGen component reads inputs from a directory, we need\n",
        "to create a directory and copy the dataset to it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZSfs6qFgdzO1"
      },
      "outputs": [],
      "source": [
        "import urllib.request\n",
        "import tempfile\n",
        "\n",
        "DATA_ROOT = tempfile.mkdtemp(prefix='tfx-data')  # Create a temporary directory.\n",
        "_data_url = 'https://raw.githubusercontent.com/tensorflow/tfx/master/tfx/examples/penguin/data/labelled/penguins_processed.csv'\n",
        "_data_filepath = os.path.join(DATA_ROOT, \"data.csv\")\n",
        "urllib.request.urlretrieve(_data_url, _data_filepath)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n5s3wGpndzO1"
      },
      "source": [
        "Take a quick look at the CSV file."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nLn9ith2dzO1"
      },
      "outputs": [],
      "source": [
        "!head {_data_filepath}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z8EOfCy1dzO2"
      },
      "source": [
        "You should be able to see five feature columns. `species` is one of 0, 1 or 2,\n",
        "and all other features should have values between 0 and 1. We will create a TFX\n",
        "pipeline to analyze this dataset."
      ]
    },
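    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The expectations above can also be checked by hand before any pipeline is run.\n",
        "The following is a minimal sketch using only the standard library. The helper\n",
        "`check_rows` is hypothetical, the inline sample rows are made up and stand in\n",
        "for the downloaded file, and every cell is assumed to be populated. To check\n",
        "the real file, pass an open file object for `_data_filepath` instead.\n",
        "\n",
        "```python\n",
        "import csv\n",
        "\n",
        "FEATURES = ['culmen_length_mm', 'culmen_depth_mm',\n",
        "            'flipper_length_mm', 'body_mass_g']\n",
        "\n",
        "\n",
        "def check_rows(csv_lines):\n",
        "  '''Returns the line numbers of rows that break the stated expectations.'''\n",
        "  bad_lines = []\n",
        "  # The header is line 1, so data rows start at line 2.\n",
        "  for line_no, row in enumerate(csv.DictReader(csv_lines), start=2):\n",
        "    species_ok = row['species'] in ('0', '1', '2')\n",
        "    features_ok = all(0.0 <= float(row[name]) <= 1.0 for name in FEATURES)\n",
        "    if not (species_ok and features_ok):\n",
        "      bad_lines.append(line_no)\n",
        "  return bad_lines\n",
        "\n",
        "\n",
        "# Hypothetical sample rows; not taken from the real dataset.\n",
        "sample = [\n",
        "    'species,culmen_length_mm,culmen_depth_mm,flipper_length_mm,body_mass_g',\n",
        "    '0,0.25,0.5,0.75,0.125',\n",
        "    '3,0.25,0.5,0.75,0.125',\n",
        "    '1,0.25,0.5,1.5,0.125',\n",
        "]\n",
        "bad_lines = check_rows(sample)\n",
        "print(bad_lines)\n",
        "```\n",
        "\n",
        "A manual check like this only covers the rules you thought to write down. The\n",
        "pipeline built in the rest of this tutorial derives such rules from the data."
      ]
    },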
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ePhfeYv0fVu1"
      },
      "source": [
        "## Generate a preliminary schema\n",
        "\n",
        "TFX pipelines are defined using Python APIs. We will create a pipeline to\n",
        "generate a schema from the input examples automatically. This schema can be\n",
        "reviewed by a human and adjusted as needed. Once the schema is finalized it can\n",
        "be used for training and example validation in later tasks.\n",
        "\n",
        "In addition to `CsvExampleGen`, which is used in\n",
        "[Simple TFX Pipeline Tutorial](https://www.tensorflow.org/tfx/tutorials/tfx/penguin_simple),\n",
        "we will use `StatisticsGen` and `SchemaGen`:\n",
        "\n",
        "- [StatisticsGen](https://www.tensorflow.org/tfx/guide/statsgen) calculates\n",
        "statistics for the dataset.\n",
        "- [SchemaGen](https://www.tensorflow.org/tfx/guide/schemagen) examines the\n",
        "statistics and creates an initial data schema.\n",
        "\n",
        "See the guides for each component or the\n",
        "[TFX components tutorial](https://www.tensorflow.org/tfx/tutorials/tfx/components_keras)\n",
        "to learn more about these components."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JUFq55kCgwsm"
      },
      "source": [
        "### Write a pipeline definition\n",
        "\n",
        "We define a function to create a TFX pipeline. A `Pipeline` object\n",
        "represents a TFX pipeline which can be run using one of the pipeline\n",
        "orchestration systems that TFX supports."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GfQ6FAk9gxJ2"
      },
      "outputs": [],
      "source": [
        "def _create_schema_pipeline(pipeline_name: str,\n",
        "                            pipeline_root: str,\n",
        "                            data_root: str,\n",
        "                            metadata_path: str) -> tfx.dsl.Pipeline:\n",
        "  \"\"\"Creates a pipeline for schema generation.\"\"\"\n",
        "  # Brings data into the pipeline.\n",
        "  example_gen = tfx.components.CsvExampleGen(input_base=data_root)\n",
        "\n",
        "  # NEW: Computes statistics over data for visualization and schema generation.\n",
        "  statistics_gen = tfx.components.StatisticsGen(\n",
        "      examples=example_gen.outputs['examples'])\n",
        "\n",
        "  # NEW: Generates schema based on the generated statistics.\n",
        "  schema_gen = tfx.components.SchemaGen(\n",
        "      statistics=statistics_gen.outputs['statistics'], infer_feature_shape=True)\n",
        "\n",
        "  components = [\n",
        "      example_gen,\n",
        "      statistics_gen,\n",
        "      schema_gen,\n",
        "  ]\n",
        "\n",
        "  return tfx.dsl.Pipeline(\n",
        "      pipeline_name=pipeline_name,\n",
        "      pipeline_root=pipeline_root,\n",
        "      metadata_connection_config=tfx.orchestration.metadata\n",
        "      .sqlite_metadata_connection_config(metadata_path),\n",
        "      components=components)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RuKFLI_Og2xr"
      },
      "source": [
        "### Run the pipeline\n",
        "\n",
        "We will use `LocalDagRunner` as in the previous tutorial."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BQspf0ajg9AO"
      },
      "outputs": [],
      "source": [
        "tfx.orchestration.LocalDagRunner().run(\n",
        "  _create_schema_pipeline(\n",
        "      pipeline_name=SCHEMA_PIPELINE_NAME,\n",
        "      pipeline_root=SCHEMA_PIPELINE_ROOT,\n",
        "      data_root=DATA_ROOT,\n",
        "      metadata_path=SCHEMA_METADATA_PATH))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VD4LsLHBi2O4"
      },
      "source": [
        "You should see \"INFO:absl:Component SchemaGen is finished.\" if the pipeline\n",
        "finished successfully.\n",
        "\n",
        "We will examine the output of the pipeline to understand our dataset."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lWpckstgg9Zs"
      },
      "source": [
        "### Review outputs of the pipeline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tL1wWoDh5wkj"
      },
      "source": [
        "As explained in the previous tutorial, a TFX pipeline produces two kinds of\n",
        "outputs: artifacts and a\n",
        "[metadata DB (MLMD)](https://www.tensorflow.org/tfx/guide/mlmd) which contains\n",
        "metadata of artifacts and pipeline executions. We defined the location of\n",
        "these outputs in the above cells. By default, artifacts are stored under\n",
        "the `pipelines` directory and metadata is stored as a SQLite database\n",
        "under the `metadata` directory.\n",
        "\n",
        "You can use MLMD APIs to locate these outputs programmatically. First, we will\n",
        "define some utility functions to search for the output artifacts that were just\n",
        "produced.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "K0i_jTvOI8mv"
      },
      "outputs": [],
      "source": [
        "from ml_metadata.proto import metadata_store_pb2\n",
        "# Non-public APIs, just for showcase.\n",
        "from tfx.orchestration.portable.mlmd import execution_lib\n",
        "\n",
        "# TODO(b/171447278): Move these functions into the TFX library.\n",
        "\n",
        "def get_latest_artifacts(metadata, pipeline_name, component_id):\n",
        "  \"\"\"Output artifacts of the latest run of the component.\"\"\"\n",
        "  context = metadata.store.get_context_by_type_and_name(\n",
        "      'node', f'{pipeline_name}.{component_id}')\n",
        "  executions = metadata.store.get_executions_by_context(context.id)\n",
        "  latest_execution = max(executions,\n",
        "                         key=lambda e:e.last_update_time_since_epoch)\n",
        "  return execution_lib.get_artifacts_dict(metadata, latest_execution.id, \n",
        "                                          metadata_store_pb2.Event.OUTPUT)\n",
        "\n",
        "# Non-public APIs, just for showcase.\n",
        "from tfx.orchestration.experimental.interactive import visualizations\n",
        "\n",
        "def visualize_artifacts(artifacts):\n",
        "  \"\"\"Visualizes artifacts using standard visualization modules.\"\"\"\n",
        "  for artifact in artifacts:\n",
        "    visualization = visualizations.get_registry().get_visualization(\n",
        "        artifact.type_name)\n",
        "    if visualization:\n",
        "      visualization.display(artifact)\n",
        "\n",
        "from tfx.orchestration.experimental.interactive import standard_visualizations\n",
        "standard_visualizations.register_standard_visualizations()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2CE1dk_3irPL"
      },
      "source": [
        "Now we can examine the outputs from the pipeline execution."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hRKSjXzsiqh0"
      },
      "outputs": [],
      "source": [
        "# Non-public APIs, just for showcase.\n",
        "from tfx.orchestration.metadata import Metadata\n",
        "from tfx.types import standard_component_specs\n",
        "\n",
        "metadata_connection_config = tfx.orchestration.metadata.sqlite_metadata_connection_config(\n",
        "    SCHEMA_METADATA_PATH)\n",
        "\n",
        "with Metadata(metadata_connection_config) as metadata_handler:\n",
        "  # Find output artifacts from MLMD.\n",
        "  stat_gen_output = get_latest_artifacts(metadata_handler, SCHEMA_PIPELINE_NAME,\n",
        "                                         'StatisticsGen')\n",
        "  stats_artifacts = stat_gen_output[standard_component_specs.STATISTICS_KEY]\n",
        "\n",
        "  schema_gen_output = get_latest_artifacts(metadata_handler,\n",
        "                                           SCHEMA_PIPELINE_NAME, 'SchemaGen')\n",
        "  schema_artifacts = schema_gen_output[standard_component_specs.SCHEMA_KEY]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9e8i0K-Aiqh-"
      },
      "source": [
        "It is time to examine the outputs from each component. As described above,\n",
        "[TensorFlow Data Validation (TFDV)](https://www.tensorflow.org/tfx/data_validation/get_started)\n",
        "is used in `StatisticsGen` and `SchemaGen`, and TFDV also\n",
        "provides visualization of the outputs from these components.\n",
        "\n",
        "In this tutorial, we will use the visualization helper methods in TFX which\n",
        "use TFDV internally to show the visualization."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GRGC4X1Ziqh-"
      },
      "source": [
        "#### Examine the output from StatisticsGen\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3StnKm04iqh-"
      },
      "outputs": [],
      "source": [
        "visualize_artifacts(stats_artifacts)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yS1XXFtfiqh-"
      },
      "source": [
        "You can see various stats for the input data. These statistics are supplied to\n",
        "`SchemaGen` to construct an initial schema of data automatically.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "20HK9JS7iqh-"
      },
      "source": [
        "#### Examine the output from SchemaGen\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MVmlot5ziqh_"
      },
      "outputs": [],
      "source": [
        "visualize_artifacts(schema_artifacts)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8ldXsv2iiqh_"
      },
      "source": [
        "This schema is automatically inferred from the output of StatisticsGen. You\n",
        "should be able to see 4 FLOAT features and 1 INT feature."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bKpFPwEWhCoB"
      },
      "source": [
        "### Export the schema for future use\n",
        "\n",
        "We need to review and refine the generated schema. The reviewed schema needs\n",
        "to be persisted to be used in subsequent pipelines for ML model training. In\n",
        "real use cases, you might want to add the schema file to your version control\n",
        "system. In this tutorial, we will just copy the schema\n",
        "to a predefined filesystem path for simplicity.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0Pyi0oaKmRTg"
      },
      "outputs": [],
      "source": [
        "import shutil\n",
        "\n",
        "_schema_filename = 'schema.pbtxt'\n",
        "SCHEMA_PATH = 'schema'\n",
        "\n",
        "os.makedirs(SCHEMA_PATH, exist_ok=True)\n",
        "_generated_path = os.path.join(schema_artifacts[0].uri, _schema_filename)\n",
        "\n",
        "# Copy the 'schema.pbtxt' file from the artifact uri to a predefined path.\n",
        "shutil.copy(_generated_path, SCHEMA_PATH)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "05U8uQ6dnlB4"
      },
      "source": [
        "The schema file uses the\n",
        "[Protocol Buffer text format](https://googleapis.dev/python/protobuf/latest/google/protobuf/text_format.html)\n",
        "and is an instance of the\n",
        "[TensorFlow Metadata Schema proto](https://github.com/tensorflow/metadata/blob/master/tensorflow_metadata/proto/v0/schema.proto)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uwHO7-HfnlWs"
      },
      "outputs": [],
      "source": [
        "print(f'Schema at {SCHEMA_PATH}-----')\n",
        "!cat {SCHEMA_PATH}/*"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BjKigLTNos4F"
      },
      "source": [
        "You should be sure to review and possibly edit the schema definition as\n",
        "needed. In this tutorial, we will just use the generated schema unchanged.\n"
      ]
    },
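    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "If you did want to edit the schema, one option is the TFDV schema utilities.\n",
        "The sketch below is not part of this tutorial's pipeline and is not executed\n",
        "here. The helper `restrict_value_ranges` is hypothetical, and the ranges it\n",
        "sets are assumptions taken from the dataset description above (features in\n",
        "[0, 1], `species` in {0, 1, 2}). It assumes the `tensorflow_data_validation`\n",
        "package that is installed alongside TFX.\n",
        "\n",
        "```python\n",
        "import os\n",
        "\n",
        "\n",
        "def restrict_value_ranges(schema_dir):\n",
        "  '''Adds value ranges to the exported schema and writes it back in place.'''\n",
        "  # Imported here so that defining the helper does not require TFDV.\n",
        "  import tensorflow_data_validation as tfdv\n",
        "  from tensorflow_metadata.proto.v0 import schema_pb2\n",
        "\n",
        "  schema_file = os.path.join(schema_dir, 'schema.pbtxt')\n",
        "  schema = tfdv.load_schema_text(schema_file)\n",
        "\n",
        "  # Assumed ranges: the numeric features were normalized to [0, 1].\n",
        "  for name in ('culmen_length_mm', 'culmen_depth_mm',\n",
        "               'flipper_length_mm', 'body_mass_g'):\n",
        "    tfdv.set_domain(schema, name, schema_pb2.FloatDomain(min=0.0, max=1.0))\n",
        "\n",
        "  # Assumed range: the label takes one of the values 0, 1 or 2.\n",
        "  tfdv.set_domain(schema, 'species', schema_pb2.IntDomain(min=0, max=2))\n",
        "\n",
        "  tfdv.write_schema_text(schema, schema_file)\n",
        "  return schema\n",
        "```\n",
        "\n",
        "Calling `restrict_value_ranges(SCHEMA_PATH)` would rewrite the exported\n",
        "`schema.pbtxt`, so later validation would flag values outside these ranges."
      ]
    },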
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nH6gizcpSwWV"
      },
      "source": [
        "## Validate input examples and train an ML model\n",
        "\n",
        "We will go back to the pipeline that we created in\n",
        "[Simple TFX Pipeline Tutorial](https://www.tensorflow.org/tfx/tutorials/tfx/penguin_simple),\n",
        "to train an ML model and use the generated schema for writing the model\n",
        "training code.\n",
        "\n",
        "We will also add an\n",
        "[ExampleValidator](https://www.tensorflow.org/tfx/guide/exampleval)\n",
        "component which will look for anomalies and missing values in the incoming\n",
        "dataset with respect to the schema.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lOjDv93eS5xV"
      },
      "source": [
        "### Write model training code\n",
        "\n",
        "We need to write the model code as we did in\n",
        "[Simple TFX Pipeline Tutorial](https://www.tensorflow.org/tfx/tutorials/tfx/penguin_simple).\n",
        "\n",
        "The model itself is the same as in the previous tutorial, but this time we will\n",
        "use the schema generated from the previous pipeline instead of specifying\n",
        "features manually. Most of the code was not changed. The only difference is\n",
        "that we do not need to specify the names and types of features in this file.\n",
        "Instead, we read them from the *schema* file."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aES7Hv5QTDK3"
      },
      "outputs": [],
      "source": [
        "_trainer_module_file = 'penguin_trainer.py'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Gnc67uQNTDfW"
      },
      "outputs": [],
      "source": [
        "%%writefile {_trainer_module_file}\n",
        "\n",
        "from typing import List\n",
        "from absl import logging\n",
        "import tensorflow as tf\n",
        "from tensorflow import keras\n",
        "from tensorflow_transform.tf_metadata import schema_utils\n",
        "\n",
        "from tfx import v1 as tfx\n",
        "from tfx_bsl.public import tfxio\n",
        "from tensorflow_metadata.proto.v0 import schema_pb2\n",
        "\n",
        "# We don't need to specify _FEATURE_KEYS and _FEATURE_SPEC any more.\n",
        "# That information can be read from the given schema file.\n",
        "\n",
        "_LABEL_KEY = 'species'\n",
        "\n",
        "_TRAIN_BATCH_SIZE = 20\n",
        "_EVAL_BATCH_SIZE = 10\n",
        "\n",
        "def _input_fn(file_pattern: List[str],\n",
        "              data_accessor: tfx.components.DataAccessor,\n",
        "              schema: schema_pb2.Schema,\n",
        "              batch_size: int = 200) -> tf.data.Dataset:\n",
        "  \"\"\"Generates features and label for training.\n",
        "\n",
        "  Args:\n",
        "    file_pattern: List of paths or patterns of input tfrecord files.\n",
        "    data_accessor: DataAccessor for converting input to RecordBatch.\n",
        "    schema: schema of the input data.\n",
        "    batch_size: representing the number of consecutive elements of returned\n",
        "      dataset to combine in a single batch\n",
        "\n",
        "  Returns:\n",
        "    A dataset that contains (features, indices) tuple where features is a\n",
        "      dictionary of Tensors, and indices is a single Tensor of label indices.\n",
        "  \"\"\"\n",
        "  return data_accessor.tf_dataset_factory(\n",
        "      file_pattern,\n",
        "      tfxio.TensorFlowDatasetOptions(\n",
        "          batch_size=batch_size, label_key=_LABEL_KEY),\n",
        "      schema=schema).repeat()\n",
        "\n",
        "\n",
        "def _build_keras_model(schema: schema_pb2.Schema) -> tf.keras.Model:\n",
        "  \"\"\"Creates a DNN Keras model for classifying penguin data.\n",
        "\n",
        "  Returns:\n",
        "    A Keras Model.\n",
        "  \"\"\"\n",
        "  # The model below is built with Functional API, please refer to\n",
        "  # https://www.tensorflow.org/guide/keras/overview for all API options.\n",
        "\n",
        "  # ++ Changed code: Uses all features in the schema except the label.\n",
        "  feature_keys = [f.name for f in schema.feature if f.name != _LABEL_KEY]\n",
        "  inputs = [keras.layers.Input(shape=(1,), name=f) for f in feature_keys]\n",
        "  # ++ End of the changed code.\n",
        "\n",
        "  d = keras.layers.concatenate(inputs)\n",
        "  for _ in range(2):\n",
        "    d = keras.layers.Dense(8, activation='relu')(d)\n",
        "  outputs = keras.layers.Dense(3)(d)\n",
        "\n",
        "  model = keras.Model(inputs=inputs, outputs=outputs)\n",
        "  model.compile(\n",
        "      optimizer=keras.optimizers.Adam(1e-2),\n",
        "      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "      metrics=[keras.metrics.SparseCategoricalAccuracy()])\n",
        "\n",
        "  model.summary(print_fn=logging.info)\n",
        "  return model\n",
        "\n",
        "\n",
        "# TFX Trainer will call this function.\n",
        "def run_fn(fn_args: tfx.components.FnArgs):\n",
        "  \"\"\"Train the model based on given args.\n",
        "\n",
        "  Args:\n",
        "    fn_args: Holds args used to train the model as name/value pairs.\n",
        "  \"\"\"\n",
        "\n",
        "  # ++ Changed code: Reads in schema file passed to the Trainer component.\n",
        "  schema = tfx.utils.parse_pbtxt_file(fn_args.schema_path, schema_pb2.Schema())\n",
        "  # ++ End of the changed code.\n",
        "\n",
        "  train_dataset = _input_fn(\n",
        "      fn_args.train_files,\n",
        "      fn_args.data_accessor,\n",
        "      schema,\n",
        "      batch_size=_TRAIN_BATCH_SIZE)\n",
        "  eval_dataset = _input_fn(\n",
        "      fn_args.eval_files,\n",
        "      fn_args.data_accessor,\n",
        "      schema,\n",
        "      batch_size=_EVAL_BATCH_SIZE)\n",
        "\n",
        "  model = _build_keras_model(schema)\n",
        "  model.fit(\n",
        "      train_dataset,\n",
        "      steps_per_epoch=fn_args.train_steps,\n",
        "      validation_data=eval_dataset,\n",
        "      validation_steps=fn_args.eval_steps)\n",
        "\n",
        "  # The result of the training should be saved in `fn_args.serving_model_dir`\n",
        "  # directory.\n",
        "  model.save(fn_args.serving_model_dir, save_format='tf')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "blaw0rs-emEf"
      },
      "source": [
        "Now you have completed all preparation steps to build a TFX pipeline for\n",
        "model training."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w3OkNz3gTLwM"
      },
      "source": [
        "### Write a pipeline definition\n",
        "\n",
        "We will add two new components, `Importer` and `ExampleValidator`. Importer\n",
        "brings an external file into the TFX pipeline. In this case, it is a file\n",
        "containing the schema definition. ExampleValidator will examine\n",
        "the input data and validate whether all input data conforms to the data schema\n",
        "we provided.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "M49yYVNBTPd4"
      },
      "outputs": [],
      "source": [
        "def _create_pipeline(pipeline_name: str, pipeline_root: str, data_root: str,\n",
        "                     schema_path: str, module_file: str, serving_model_dir: str,\n",
        "                     metadata_path: str) -> tfx.dsl.Pipeline:\n",
        "  \"\"\"Creates a TFX pipeline that uses a predefined schema.\"\"\"\n",
        "  # Brings data into the pipeline.\n",
        "  example_gen = tfx.components.CsvExampleGen(input_base=data_root)\n",
        "\n",
        "  # Computes statistics over data for visualization and example validation.\n",
        "  statistics_gen = tfx.components.StatisticsGen(\n",
        "      examples=example_gen.outputs['examples'])\n",
        "\n",
        "  # NEW: Import the schema.\n",
        "  schema_importer = tfx.dsl.Importer(\n",
        "      source_uri=schema_path,\n",
        "      artifact_type=tfx.types.standard_artifacts.Schema).with_id(\n",
        "          'schema_importer')\n",
        "\n",
        "  # NEW: Performs anomaly detection based on statistics and data schema.\n",
        "  example_validator = tfx.components.ExampleValidator(\n",
        "      statistics=statistics_gen.outputs['statistics'],\n",
        "      schema=schema_importer.outputs['result'])\n",
        "\n",
        "  # Uses user-provided Python function that trains a model.\n",
        "  trainer = tfx.components.Trainer(\n",
        "      module_file=module_file,\n",
        "      examples=example_gen.outputs['examples'],\n",
        "      schema=schema_importer.outputs['result'],  # Pass the imported schema.\n",
        "      train_args=tfx.proto.TrainArgs(num_steps=100),\n",
        "      eval_args=tfx.proto.EvalArgs(num_steps=5))\n",
        "\n",
        "  # Pushes the model to a filesystem destination.\n",
        "  pusher = tfx.components.Pusher(\n",
        "      model=trainer.outputs['model'],\n",
        "      push_destination=tfx.proto.PushDestination(\n",
        "          filesystem=tfx.proto.PushDestination.Filesystem(\n",
        "              base_directory=serving_model_dir)))\n",
        "\n",
        "  components = [\n",
        "      example_gen,\n",
        "\n",
        "      # NEW: The following three components were added to the pipeline.\n",
        "      statistics_gen,\n",
        "      schema_importer,\n",
        "      example_validator,\n",
        "\n",
        "      trainer,\n",
        "      pusher,\n",
        "  ]\n",
        "\n",
        "  return tfx.dsl.Pipeline(\n",
        "      pipeline_name=pipeline_name,\n",
        "      pipeline_root=pipeline_root,\n",
        "      metadata_connection_config=tfx.orchestration.metadata\n",
        "      .sqlite_metadata_connection_config(metadata_path),\n",
        "      components=components)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mJbq07THU2GV"
      },
      "source": [
        "### Run the pipeline\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fAtfOZTYWJu-"
      },
      "outputs": [],
      "source": [
        "tfx.orchestration.LocalDagRunner().run(\n",
        "  _create_pipeline(\n",
        "      pipeline_name=PIPELINE_NAME,\n",
        "      pipeline_root=PIPELINE_ROOT,\n",
        "      data_root=DATA_ROOT,\n",
        "      schema_path=SCHEMA_PATH,\n",
        "      module_file=_trainer_module_file,\n",
        "      serving_model_dir=SERVING_MODEL_DIR,\n",
        "      metadata_path=METADATA_PATH))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AZ3nTzG8uAzn"
      },
      "source": [
        "You should see \"INFO:absl:Component Pusher is finished.\" if the pipeline\n",
        "finished successfully."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uuD5FRPAcOn8"
      },
      "source": [
        "### Examine outputs of the pipeline\n",
        "\n",
        "We have trained the classification model for penguins, and we also have\n",
        "validated the input examples in the ExampleValidator component. We can analyze\n",
        "the output from ExampleValidator as we did with the previous pipeline."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TtsrZEUB1-J4"
      },
      "outputs": [],
      "source": [
        "metadata_connection_config = tfx.orchestration.metadata.sqlite_metadata_connection_config(\n",
        "    METADATA_PATH)\n",
        "\n",
        "with Metadata(metadata_connection_config) as metadata_handler:\n",
        "  ev_output = get_latest_artifacts(metadata_handler, PIPELINE_NAME,\n",
        "                                   'ExampleValidator')\n",
        "  anomalies_artifacts = ev_output[standard_component_specs.ANOMALIES_KEY]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3U5MNAUIdBtN"
      },
      "source": [
        "ExampleAnomalies from the ExampleValidator can be visualized as well."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "F-4oAjGR-IR0"
      },
      "outputs": [],
      "source": [
        "visualize_artifacts(anomalies_artifacts)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t026ZzbU0961"
      },
      "source": [
        "You should see \"No anomalies found\" for each split of examples. Because this\n",
        "pipeline used the same data that the schema was generated from,\n",
        "no anomalies are expected here. If you run this pipeline repeatedly with new\n",
        "incoming data, ExampleValidator should be able to find any discrepancies\n",
        "between the new data and the existing schema.\n",
        "\n",
        "If any anomalies are found, review your data to check whether any\n",
        "examples do not follow your assumptions. Outputs from other components like\n",
        "StatisticsGen might be useful. However, any anomalies which are found will\n",
        "NOT block further pipeline executions."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "08R8qvweThRf"
      },
      "source": [
        "## Next steps\n",
        "\n",
        "You can find more resources on https://www.tensorflow.org/tfx/tutorials.\n",
        "\n",
        "Please see\n",
        "[Understanding TFX Pipelines](https://www.tensorflow.org/tfx/guide/understanding_tfx_pipelines)\n",
        "to learn more about various concepts in TFX.\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "DjUA6S30k52h"
      ],
      "name": "penguin_tfdv.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
